from Libraries import *
def plot_2d_new(model, data, device='cpu'):
    """Show the two-class decision regions of ``model`` with ``data`` on top.

    A 1000 x 1000 grid over [-5, 5] x [-5, 5] is pushed through ``model`` and
    each grid point is coloured by comparing the first two output columns:
    points where column 0 <= column 1 are drawn in '#fdcdac', points where
    column 0 > column 1 in '#b3cde3'. The visible window is then cropped to
    [-4, 4] x [-4, 4].

    Args:
        model: Callable that takes an ``(N, 2)`` tensor created on ``device``
            and returns a tensor indexable as ``[:, 0]`` and ``[:, 1]``.
        data: Indexable whose ``data[0]`` and ``data[1]`` are handed to
            ``plt.scatter`` as the x and y coordinates of the black points.
        device: Device on which the evaluation grid is created.

    Side effects:
        Sets matplotlib's global ``font.size`` rcParam to 30 (this persists
        after the call) and displays the figure with ``plt.show()``.
    """
    x = torch.linspace(-5, 5, 1000, device=device)
    y = torch.linspace(-5, 5, 1000, device=device)
    # Two (1000, 1000) coordinate grids -> flattened to one (1_000_000, 2) batch.
    inputs = torch.stack(torch.meshgrid(x, y)).view(2, -1).t()

    # Pure inference: do not record an autograd graph for a million points.
    with torch.no_grad():
        outputs = model(inputs)  # classifying the points

    red = inputs[outputs[:, 0] <= outputs[:, 1]]
    blue = inputs[outputs[:, 0] > outputs[:, 1]]

    plt.scatter(red[:, 0].cpu(), red[:, 1].cpu(), color='#fdcdac')
    plt.scatter(blue[:, 0].cpu(), blue[:, 1].cpu(), color='#b3cde3')
    plt.scatter(data[0], data[1], c='k')
    plt.xlim(-4, 4)
    plt.ylim(-4, 4)
    matplotlib.rcParams.update({'font.size': 30})
    plt.show()
    
def _closed_hull_outline(vertices, hull):
    """Return x and y lists tracing ``hull.vertices`` of ``vertices`` as a closed loop."""
    xs = vertices[hull.vertices, 0].tolist()
    ys = vertices[hull.vertices, 1].tolist()
    # Repeat the first vertex so the drawn outline closes on itself.
    xs.append(xs[0])
    ys.append(ys[0])
    return xs, ys


def plot_polytope_2(zonotope1_vertices,hull1,zonotope2_vertices,hull2,legend = 'legend',legend2 = 'legend'):
    """Draw two 2-D polytopes on the current matplotlib axes.

    The first polytope is drawn as a solid green outline, the second as a
    dashed blue outline. When the second polytope has more than 3 vertices,
    the axes are limited to [0, 5] x [0, 5], the global ``font.size`` rcParam
    is set to 30, and the end points of every simplex of both hulls are
    marked with red dots.

    Args:
        zonotope1_vertices: Array indexed as ``[i, 0]`` (x) and ``[i, 1]`` (y),
            with a ``shape`` attribute.
        hull1: Object exposing ``vertices`` and ``simplices`` index arrays for
            ``zonotope1_vertices`` (presumably a ``scipy.spatial.ConvexHull``;
            verify against the caller).
        zonotope2_vertices: Same layout as ``zonotope1_vertices``.
        hull2: Hull object for ``zonotope2_vertices``, same form as ``hull1``.
        legend: Currently unused; the legend call is disabled.
        legend2: Currently unused; the legend call is disabled.

    Note:
        A polytope with 3 or fewer vertices is drawn by connecting its
        vertices directly instead of going through its hull.
    """
    if zonotope1_vertices.shape[0] > 3:
        xp, yp = _closed_hull_outline(zonotope1_vertices, hull1)
        plt.plot(xp, yp, 'g-', lw=4)
        # Explicit True: a bare plt.grid() toggles, so a second call would
        # switch the grid off again.
        plt.grid(True)
    else:
        # Column 0 is x and column 1 is y, matching the hull branch.
        plt.plot(zonotope1_vertices[:, 0], zonotope1_vertices[:, 1], 'g-')

    if zonotope2_vertices.shape[0] > 3:
        xp, yp = _closed_hull_outline(zonotope2_vertices, hull2)
        plt.plot(xp, yp, 'b--', lw=4)
        plt.grid(True)
        plt.xlim(0, 5)
        plt.ylim(0, 5)
        matplotlib.rcParams.update({'font.size': 30})
        # Mark the hull vertices of both polytopes with red dots.
        for simplex in hull2.simplices:
            plt.plot(zonotope2_vertices[simplex, 0], zonotope2_vertices[simplex, 1], 'ro')
        for simplex in hull1.simplices:
            plt.plot(zonotope1_vertices[simplex, 0], zonotope1_vertices[simplex, 1], 'ro')
    else:
        # Draw the second polytope here (the first one was plotted by mistake).
        plt.plot(zonotope2_vertices[:, 0], zonotope2_vertices[:, 1], 'b--')
